@miladm (Collaborator) commented Aug 19, 2022

Add sym_sizes with dynamic shape support

@miladm requested review from JackCaoG and Krovatkin on August 19, 2022 20:52
@miladm self-assigned this on Aug 19, 2022
@miladm added the dynamism (Dynamic Shape Features) label on Aug 19, 2022
@miladm added this to the Dynamic Shape milestone on Aug 19, 2022
@miladm (Collaborator, Author) commented Aug 19, 2022

This PR shows that this implementation is correct. However, there are other issues we want to resolve in the mentioned PR that may require extra time. This PR proposes landing sym_sizes independently, given all we know about its implementation methodology.

@Krovatkin force-pushed the sym_sizes_dynamic branch 2 times, most recently from ceb2954 to 149d61e on September 13, 2022 21:46
@Krovatkin (Contributor) commented:

@JackCaoG I remember you had some doubts about this PR?

@Krovatkin (Contributor) commented:

Never mind, I'm seeing real failures now; I need to investigate. @JackCaoG

@JackCaoG (Collaborator) commented:

Seems OK from a high level. I'm not sure what this PR enables us to do, though. Is it .sym_size for XLATensorImpl?

return sizes_default();
}

c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
Collaborator: Wondering what the "custom" in sym_sizes_custom() implies.

Contributor: custom means that a given TensorImpl can implement sym_sizes in whichever way it sees fit. Note that sym_sizes_custom is a virtual method, whereas sym_sizes() isn't: sym_sizes() checks whether a given tensor impl sets the CustomSizes policy and, if so, calls the virtual sym_sizes_custom. This is somewhat convoluted, but it was done to make sure we don't pay a virtual-call penalty on the fast path: CPU and CUDA tensors don't need a custom sym_sizes implementation.
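
A minimal sketch of the dispatch pattern described above (paraphrased and simplified; the real logic lives in c10::TensorImpl, and the boolean flag here is only a stand-in for the actual sizes policy check):

```cpp
#include <vector>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>

// Simplified sketch, not the actual c10::TensorImpl source.
class TensorImplSketch {
 public:
  virtual ~TensorImplSketch() = default;

  // Non-virtual fast path: plain CPU/CUDA tensors never pay for a
  // virtual call here.
  c10::SymIntArrayRef sym_sizes() const {
    if (has_custom_sizes_) {      // stand-in for the CustomSizes policy check
      return sym_sizes_custom();  // virtual dispatch only for opted-in backends
    }
    return sym_sizes_default();
  }

 protected:
  // Backends such as XLA override this to report symbolic sizes.
  virtual c10::SymIntArrayRef sym_sizes_custom() const {
    return sym_sizes_default();
  }

  c10::SymIntArrayRef sym_sizes_default() const {
    return c10::SymIntArrayRef(static_sizes_);  // static sizes only
  }

  std::vector<c10::SymInt> static_sizes_;
  bool has_custom_sizes_ = false;  // set by backends that opt in
};
```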

Collaborator: Thanks.

sym_sizes_.push_back(sn);
// TODO(miladm): verify numel_ calculation after adding a dynamic op
numel_ *= dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
Collaborator: I've read the definition of numel_ in the base class, but it's not clear to me what it means. Do you know the meaning of this member?

Collaborator: The number of elements in the tensor.

Collaborator: Thanks.
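
For concreteness, numel_ is just the product of all dimension sizes, which is what the numel_ *= ... loop above accumulates; a minimal illustration (my example, not code from this PR):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // A tensor of shape [2, 3, 4] holds 2 * 3 * 4 = 24 elements.
  const int64_t sizes[] = {2, 3, 4};
  int64_t numel = 1;
  for (int64_t dim : sizes) {
    numel *= dim;  // mirrors the numel_ *= ...getStaticValue() loop above
  }
  assert(numel == 24);
  return 0;
}
```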

@Krovatkin (Contributor) commented Sep 19, 2022:

> Seems OK from a high level. I'm not sure what this PR enables us to do, though. Is it .sym_size for XLATensorImpl?

This PR finally opens Pandora's box, hehe. For ops like nonzero, size() and shape in Python will now return real sym ints, as will any C++ code that calls sym_sizes().

ASSERT_EQ(a.sym_sizes().at(0).is_symbolic(), false);
});
}

Collaborator: Should we add a dynamic test? You can call nonzero, but you would need to make sure https://github.com/pytorch/xla/blob/master/test/cpp/run_tests.sh#L11 is set so nonzero won't fall back to CPU.
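
A rough sketch of the kind of dynamic test being suggested (hypothetical, not code from this PR; TEST_F, ForEachDevice, CopyToDevice, and the fixture name are assumed from the existing test/cpp harness, and it relies on the experimental nonzero flag from run_tests.sh being set):

```cpp
// Assumes the usual test/cpp includes and the googletest harness.
TEST_F(AtenXlaTensorTest, TestNonZeroSymbolicSize) {
  torch::Tensor a = torch::tensor({0, 1, 0, 2});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor nz = torch::nonzero(xla_a);
    // nonzero's first dimension depends on runtime data, so with dynamic
    // shape support enabled it should be reported as symbolic.
    ASSERT_TRUE(nz.sym_sizes().at(0).is_symbolic());
  });
}
```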

Comment on lines 6 to 9
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/trie.h>
Collaborator: Do we still need these includes?

@Krovatkin force-pushed the sym_sizes_dynamic branch 2 times, most recently from 3954208 to 89062bd on September 26, 2022 18:02
sizes_and_strides_.set_sizes(updated_sizes);
auto updated_strides = torch::lazy::ComputeArrayStrides(
torch::lazy::ToVector<int64_t>(shape.get().dimensions()));
for (int i = 0; i < updated_strides.size(); i++) {
@miladm (Collaborator, Author) commented Sep 27, 2022:

Shouldn't we move this for loop inside SetupSymSizeProperties()? That would make the SetupSizeProperties and SetupSymSizeProperties implementations consistent. Is there a reason we want it called outside SetupSymSizeProperties? @Krovatkin

@Krovatkin (Contributor) commented Sep 27, 2022:

If I understand your point, you're suggesting duplicating the strides logic inside SetupSymSizeProperties as well, so that it updates both sym_sizes and sym_strides the same way SetupSizeProperties updates sizes and strides?

If that's your concern, I believe it's fine for now, since sym_strides still returns static strides. We should figure out whether we need support for dynamic strides as well, or whether it's okay for sym_strides to continue returning static strides.

@miladm (Collaborator, Author) commented:

@ezyang: based on our offline conversation about sym_strides under functionalization, it seems we assume there is no need for dynamic strides; correct?
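
For context, the strides logic under discussion computes contiguous row-major strides from the sizes. A minimal sketch of that computation (my paraphrase under the row-major assumption, not the torch::lazy::ComputeArrayStrides source):

```cpp
#include <cstdint>
#include <vector>

// Row-major strides: the last dimension has stride 1, and each earlier
// dimension's stride is the product of all later dimension sizes.
std::vector<int64_t> ComputeContiguousStrides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}

// e.g. sizes {2, 3, 4} -> strides {12, 4, 1}
```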

sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
}
SetupSymSizeProperties();
generation_ = generation;
@miladm (Collaborator, Author): Ditto.

xla_b = torch::nonzero(xla_b);
auto s0 = xla_b.sym_sizes().at(0);
ASSERT_EQ(s0.is_symbolic(), true);
auto sininode =
@miladm (Collaborator, Author) commented Sep 27, 2022:

nit: is the variable name meant to be sininode?


@JackCaoG merged commit 0efa9e4 into master on Sep 28, 2022.